#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 20 11:09:03 2018

@author: annabelle

"""
import pickle
import os
from stringCore import *


# Octagonal 7x7 post-selection window: 1 marks a site inside the window
# (3 + 5 + 7 + 7 + 7 + 5 + 3 = 37 sites).
_WINDOW_ROWS = (
    (0, 0, 1, 1, 1, 0, 0),
    (0, 1, 1, 1, 1, 1, 0),
    (1, 1, 1, 1, 1, 1, 1),
    (1, 1, 1, 1, 1, 1, 1),
    (1, 1, 1, 1, 1, 1, 1),
    (0, 1, 1, 1, 1, 1, 0),
    (0, 0, 1, 1, 1, 0, 0),
)


def _checkerboardSigns(shape):
    """Return an integer array holding (-1) ** (i + j) at every index (i, j) of `shape`."""
    i, j = np.indices(shape)
    return 1 - 2 * ((i + j) % 2)


def _staggeredSum(a):
    """Return the sum over (i, j) of (-1) ** (i + j) * a[i, j] for a 2D array `a`."""
    return np.sum(_checkerboardSigns(a.shape) * a)


def findBestWindow(amsb):
    """Pick, in every snapshot, the 7x7 window with the largest staggered magnetization.

    For each image this computes the staggered magnetization of the whole
    image. It then slides the octagonal window described by _WINDOW_ROWS over
    the image and keeps only positions where every window site is non-zero,
    i.e. the sum of |m| inside the window equals the number of window sites.
    Of those positions, the first one in row-major scan order with the largest
    |staggered magnetization| is cropped out and compared to a Neel reference
    state.

    Parameters
    ----------
    amsb : sequence of 2D numpy arrays
        Snapshots; zero entries are treated as outside the region of interest.
        Presumably spin values in {-1, 0, +1} -- TODO confirm.

    Returns
    -------
    mzs_system : list
        Staggered magnetization of each full image divided by sum(|m|).
    ms : list of 2D arrays
        The cropped best window of each image (sites outside the octagon are 0).
    mDiffs : list of 2D arrays
        |window - ref|, where ref is the Neel pattern carrying the sign of the
        window's staggered magnetization, zeroed where the window is zero.
        For +/-1 spins a mismatched site has value 2.
    mzs_window : list
        Unnormalised staggered magnetization of each cropped window, computed
        in the cropped window's own coordinates.

    Raises
    ------
    ValueError
        If an image admits no window position lying completely inside its
        region of interest.
    """
    window = np.array(_WINDOW_ROWS, dtype=bool)
    winRows, winCols = window.shape
    windowSites = np.sum(window)

    mzs_system = []
    ms = []
    mDiffs = []
    mzs_window = []

    for img in amsb:
        m = img.copy()
        signs = _checkerboardSigns(m.shape)

        # staggered magnetization of the whole system, normalised by sum(|m|)
        totalSites = np.sum(np.abs(m))
        mzs_system.append(np.sum(signs * m) * 1.0 / totalSites)

        # Scan all window positions; keep the first one with the largest
        # |staggered magnetization| among those lying fully inside the ROI.
        # NOTE(review): the scan skips offset 0 and the last offset at which the
        # window would still fit; kept as-is to reproduce previous results.
        best = None
        bestSm = None
        for x0 in range(1, m.shape[0] - winRows):
            for y0 in range(1, m.shape[1] - winCols):
                mask = np.zeros(m.shape, dtype=bool)
                mask[x0:x0 + winRows, y0:y0 + winCols] = window
                m1 = m.copy()
                m1[~mask] = 0
                if np.sum(np.abs(m1)) != windowSites:
                    continue  # window reaches outside the ROI
                sm = np.sum(signs * m1)
                if best is None or abs(sm) > abs(bestSm):
                    best, bestSm = m1, sm

        if best is None:
            raise ValueError(
                'no {}x{} window fits inside the region of interest of an image '
                'of shape {}'.format(winRows, winCols, m.shape))

        # cut the window out of the system: bounding box of its non-zero sites
        rows, cols = np.nonzero(best)
        m1 = best[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

        # staggered magnetization of the cropped window (its own coordinates)
        mz = _staggeredSum(m1)
        mzs_window.append(mz)

        # reference Neel state with the window's sign, restricted to its sites
        ref = (_checkerboardSigns(m1.shape) * np.sign(mz)).astype(m1.dtype)
        ref[m1 == 0] = 0

        ms.append(m1)
        mDiffs.append(abs(m1 - ref))

    return mzs_system, ms, mDiffs, mzs_window


def _stringLengthHistogram(mDiff, m, lMax):
    """Return the red-area length distribution of one snapshot, normalised per site.

    Red sites are entries of `mDiff` equal to 2. They are grouped into areas
    with findRedArea (from stringCore), and each area contributes one count at
    length len(area) - 1. The result is a list of lMax + 1 floats: the number
    of areas of length 0..lMax divided by the number of non-zero sites of `m`.
    Areas longer than lMax are not counted.
    """
    mDiff = np.copy(mDiff)  # zeroed in place below; leave the caller's array intact
    stringLengths = []

    redSites = np.transpose(np.where(mDiff == 2)).tolist()
    while redSites:
        while redSites:
            redArea = findRedArea(redSites[0], [], redSites, mDiff, mDiff.shape[0])
            for redSite in redArea:
                redSites.remove(redSite)
            stringLengths.append(len(redArea) - 1)
            for red in redArea:
                mDiff[red[0], red[1]] = 0
        # rescan for red sites still present in mDiff and repeat
        redSites = np.transpose(np.where(mDiff == 2)).tolist()

    nSites = np.count_nonzero(m)
    return [stringLengths.count(i) * 1.0 / nSites for i in range(lMax + 1)]


def shotEval(dir, case, gethist, ps, label, detect='redAreas'):
    """Extract string-length statistics from the snapshot files of one dataset.

    For every file in `dir` whose name starts with `case` (except files ending
    in 'params.pkl') this:
      1. looks up the doping value and its error in the parameters pickle;
      2. loads the snapshots, keeps the best 7x7 window of each
         (findBestWindow), and post-selects the fraction `ps` of snapshots
         with the largest |window staggered magnetization|;
      3. groups sites equal to 2 in the Neel difference image into red areas
         (findRedArea) and histograms their lengths, normalised per site;
      4. records the string count: the summed weight of lengths >= 2.

    Parameters
    ----------
    dir : str
        Directory holding the snapshot pickles. It is concatenated directly
        with the file names, so it must end with a path separator.
    case : str
        File-name prefix selecting the dataset. Its last '_' token with digits
        removed names the parameters pickle that is loaded.
    gethist : bool
        If true, pickle each file's length histogram into the FCS directory.
    ps : float
        Fraction of snapshots kept after sorting by |window magnetization|.
    label : str
        Suffix inserted into the output file names.
    detect : str
        Currently unused; output files are always tagged '_redAreas'.

    Raises
    ------
    ValueError
        If post-selection keeps no snapshot of a file, or more snapshots than
        the file contains.
    """
    # hard-coded data locations (explicit escapes; same paths as before)
    paramsPath = "V:\\Paper Data\\AFM_StringTheory\\Snapshots_experiment\\{}_params.pkl"
    fcsDir = "V:\\Paper Data\\AFM_StringTheory\\FCS\\"
    lMax = 8  # longest string length resolved in the histogram

    # This is a really longwinded way to get cold, hot, D: the parameters file
    # is named after the last '_' token of `case` with its digits removed.
    paramsName = ''.join(c for c in case.split('_')[-1] if not c.isdigit())
    # NOTE(review): pickle.load can execute arbitrary code; only use on trusted data files.
    with open(paramsPath.format(paramsName), 'rb') as f:
        params = pickle.load(f)

    # NOTE(review): for 'D' cases params['doping'] is indexed by str keys below,
    # so iterating it here would yield keys rather than dopings -- confirm.
    rounded_dopings = [int(round(d / 2.) * 2) for d in params['doping']]

    dopings = []
    dopingerrs = []
    scs = []
    sces = []

    for fn in os.listdir(dir):
        if not fn.startswith(case) or fn.endswith('params.pkl'):
            continue  # not part of this dataset
        print("Evaluating {} for {} {}".format(fn, case, label))

        # doping tag: last '_' token without its first character and 4-char extension
        rounddope = int(fn.split('_')[-1][1:-4])

        pdopes = params['doping'][str(rounddope)] if 'D' in case else params['doping']
        pdopeerrs = params['dopingerr'][str(rounddope)] if 'D' in case else params['dopingerr']

        if len(case.split('_')) > 1:
            # simulated data: exact doping value, but the half-filling doping error
            dopings.append(rounddope)
            dopingerrs.append(pdopeerrs[rounded_dopings.index(0)])
        else:
            # experiment: measured doping and its error for this rounded doping
            idx = rounded_dopings.index(rounddope)
            dopings.append(pdopes[idx])
            dopingerrs.append(pdopeerrs[idx])

        print("  Doping      {} +/- {}".format(dopings[-1], dopingerrs[-1]))

        with open(dir + fn, 'rb') as f:
            ams = pickle.load(f)
        amsb = list(ams[1]) + list(ams[2])

        mzs_system, ms, mDiffs, mzs_window = findBestWindow(amsb)

        print('  Sorting snapshots!')
        indices = sorted(range(len(mzs_window)), key=lambda k: -abs(mzs_window[k]))
        mzs_window = [mzs_window[i] for i in indices]
        mDiffs = [mDiffs[i] for i in indices]
        ms = [ms[i] for i in indices]

        numberShots = int(ps * len(mzs_window) + 0.5)
        print('    Keeping best {} snapshots of all {}'.format(numberShots, len(mDiffs)))
        if not 1 <= numberShots <= len(mzs_window):
            raise ValueError('ps={} keeps {} of {} snapshots in {}'.format(
                ps, numberShots, len(mzs_window), fn))
        print('    Lowest staggered magnetization accepted: ' + str(abs(mzs_window[numberShots - 1])))

        print('  Extracting string patterns!')
        lengthFCS = []
        for shot in range(numberShots):
            if shot % 500 == 0:
                print('    ...searching shot no. ' + str(shot))
            lengthFCS.append(_stringLengthHistogram(mDiffs[shot], ms[shot], lMax))

        # mean and standard error of each length bin over the kept snapshots
        # (the error is nan when only one snapshot is kept)
        lengthFCS = np.transpose(lengthFCS)
        lengthFCS_mean = np.array([np.mean(lengthFCS[i]) for i in range(lMax + 1)])
        lengthFCS_std = np.array([np.std(lengthFCS[i]) * 1.0 / np.sqrt(numberShots - 1)
                                  for i in range(lMax + 1)])

        # string count: summed weight of strings of length >= 2
        sc = np.sum(lengthFCS_mean[2:])
        sce = np.sqrt(np.sum(lengthFCS_std[2:] ** 2))
        scs.append(sc)
        sces.append(sce)
        print('    String count:  ' + str(sc) + ' ' + str(sce))

        if gethist:
            FCSdict = {'lengthFCS_mean': lengthFCS_mean,
                       'lengthFCS_std': lengthFCS_std,
                       'doping': dopings[-1],
                       'dopingerr': dopingerrs[-1]}
            with open(fcsDir + fn[:-4] + label + '_redAreas_hist.pkl', 'wb') as f:
                pickle.dump(FCSdict, f)

    # only output averages vs. doping if there are many doping values, say more than 3
    if len(dopings) > 3:
        # Reorder every list with ONE permutation sorted by doping. Sorting each
        # (doping, value) list independently breaks ties by the value itself and
        # can misalign entries when two files share a doping value.
        order = sorted(range(len(dopings)), key=lambda k: dopings[k])
        FCSdict = {'doping': np.array([dopings[k] for k in order]),
                   'dopingerr': np.array([dopingerrs[k] for k in order]),
                   'sc': np.array([scs[k] for k in order]),
                   'sces': np.array([sces[k] for k in order])}

        with open(fcsDir + case + label + '_redAreas_avgs.pkl', 'wb') as f:
            pickle.dump(FCSdict, f)


if __name__ == '__main__':
    # Root of the data tree. Escapes are explicit; the trailing separator is
    # required because shotEval concatenates directory and file name directly.
    dataRoot = "V:\\Paper Data\\AFM_StringTheory\\"
    ps = 0.6  # fraction of snapshots kept after post-selection

    # Fig S3 data: one snapshot subdirectory per dataset prefix
    subdirs = ['Snapshots_experiment\\', 'Snapshots_sprinkled\\', 'Snapshots_strings\\']
    cases = ['cold', 'peak0_cold', 'gst0.60J_cold']
    gethist = True
    for subdir, case in zip(subdirs, cases):
        shotEval(dataRoot + subdir, case, gethist, ps, label='')

